[PyTorch] Support single parameter for GroupedLinear#2731
ksivaman merged 6 commits into NVIDIA:main
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
Greptile Summary

This PR introduces an option to expose a single grouped weight parameter for the `GroupedLinear` module.

Core changes:

Minor issue found: a duplicate comment line. The core refactoring of the storage class, API cleanup, and C++ dispatch update is sound, with appropriate assertion guards against incompatible quantizer configurations.

Confidence Score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant GL as GroupedLinear.__init__
    participant RP as reset_parameters
    participant MGW as make_grouped_weights
    participant GTS as GroupedTensorStorage.make_grouped_tensor
    participant GT as GroupedTensor (wrapper)
    participant GTS2 as GroupedTensorStorage (internal)
    GL->>RP: single_grouped_parameter=True
    RP->>MGW: make_grouped_weights(defer_init)
    MGW->>GTS: make_grouped_tensor_with_shapes(shapes, quantizer)
    alt quantizer.internal == False (or quantizer is None)
        GTS->>GT: GroupedTensor(shape, dtype, num_tensors, ...)
        GT-->>MGW: grouped_weights (torch.Tensor subclass)
        MGW->>GL: register_parameter("weight", nn.Parameter(grouped_weights))
    else quantizer.internal == True
        GTS->>GTS2: GroupedTensorStorage(shape, dtype, num_tensors, ...)
        GTS2-->>MGW: grouped_weights (plain Python object)
        MGW->>GL: assert fails → error raised
    end
    MGW->>GL: register_parameter("weight{i}", None) for each GEMM
    note over GT: GroupedTensor.__torch_dispatch__ intercepts ops,<br/>dequantizes members → stacked tensor → op → requantizes
```
Last reviewed commit: ef58675
```python
# Re-register as a single grouped weight parameter.
self.register_parameter(
    "weight",
    torch.nn.Parameter(grouped_weights),
    init_fn=self.init_method,
    get_rng_state_tracker=self.get_rng_state_tracker,
    fp8_meta_index=self._offsets["weight"],
)
```
nn.Parameter wrapping will crash if quantizer.internal=True
GroupedTensorStorage.make_grouped_tensor() branches on quantizer.internal: when True it returns a plain GroupedTensorStorage, which is not a torch.Tensor subclass. Passing that object to torch.nn.Parameter(...) will raise a TypeError at runtime.
The early-return guard at line 766 covers delayed() and float8_current_scaling(), but not MXFP8, Float8BlockScaling, or NVFP4 with an internal quantizer. If any of those quantizers are used as weight quantizers with internal=True, this line will crash:
```python
self.register_parameter("weight", torch.nn.Parameter(grouped_weights), ...)
```

Consider adding an explicit assertion before this call:
Suggested change:

```python
# Re-register as a single grouped weight parameter.
assert isinstance(grouped_weights, torch.Tensor), (
    "single_grouped_parameter requires a GroupedTensor (torch.Tensor subclass); "
    "got GroupedTensorStorage (quantizer.internal=True is unsupported here)."
)
self.register_parameter(
    "weight",
    torch.nn.Parameter(grouped_weights),
    init_fn=self.init_method,
    get_rng_state_tracker=self.get_rng_state_tracker,
    fp8_meta_index=self._offsets["weight"],
)
```
```python
def __repr__(self) -> str:
    """String representation of the GroupedTensor."""
    return (
        f"GroupedTensor(num_tensors={self.num_tensors}, "
        f"shape={self.shape}, "
        f"shapes={self.tensor_shapes}, "
        f"logical_shape={self.logical_shape}, "
        f"quantizer={self.quantizer}, "
        f"dtype={self.get_dtype()})"
    )
```
Stale class name in __repr__ output
After the rename of this class from GroupedTensor to GroupedTensorStorage, the repr string still emits GroupedTensor(...). This makes it confusing to distinguish the storage object from the new GroupedTensor wrapper when debugging.
Suggested change:

```python
def __repr__(self) -> str:
    """String representation of the GroupedTensorStorage."""
    return (
        f"GroupedTensorStorage(num_tensors={self.num_tensors}, "
        f"shapes={self.tensor_shapes}, "
        f"logical_shape={self.logical_shape}, "
        f"quantizer={self.quantizer}, "
        f"dtype={self.get_dtype()})"
    )
```
```diff
@@ -314,20 +296,20 @@ def make_grouped_tensor_with_shapes(
     """
```
Return type annotation is too narrow
make_grouped_tensor_with_shapes() (and make_grouped_tensor() at line 333) are annotated as returning GroupedTensorStorage, but they actually return a GroupedTensor (a torch.Tensor subclass) when quantizer.internal is False — which is the common case for user-facing weight parameters.
Looking at lines 564–569 of make_grouped_tensor():
```python
internal = False if quantizer is None else quantizer.internal
if internal:
    grouped_tensor_class = GroupedTensorStorage
else:
    from ..grouped_tensor import GroupedTensor
    grouped_tensor_class = GroupedTensor
```

Callers in grouped_linear.py wrap the return value in `torch.nn.Parameter`, which only works for `torch.Tensor` subclasses. The annotation does not convey this requirement and will mislead type-checkers.
Consider updating the return type annotation to Union[GroupedTensorStorage, GroupedTensor] or adding a note to the docstring clarifying that the returned type may be a GroupedTensor when quantizer.internal=False.
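A minimal sketch of the suggested annotation, with stand-in stub classes in place of the real ones (the real functions take more arguments):

```python
from typing import Union

class GroupedTensorStorage:
    """Stand-in for the plain-Python storage class."""

class GroupedTensor:
    """Stand-in for the torch.Tensor wrapper subclass."""

def make_grouped_tensor_with_shapes(
    shapes, quantizer=None
) -> Union[GroupedTensorStorage, GroupedTensor]:
    """Returns a GroupedTensor unless quantizer.internal is True."""
    internal = False if quantizer is None else quantizer.internal
    return GroupedTensorStorage() if internal else GroupedTensor()

# Common case (no quantizer): the user-facing wrapper type comes back.
assert isinstance(make_grouped_tensor_with_shapes([(2, 3)]), GroupedTensor)
```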
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
```diff
 set_tensor_model_parallel_attributes(
-    tensor=getattr(self, f"weight{i}"),
+    tensor=grouped_weight,
     is_parallel=True,
     dim=1 if self.parallel_mode == "row" else 0,
     stride=1,
 )
```
Wrong partition_dim for 3D grouped weight tensor
When single_grouped_parameter=True, the grouped weight has shape [num_gemms, out_features, in_features]. The same dim values used for the per-GEMM 2D weights (out_features, in_features) are reused here without adjustment:
- `"row"` parallel → `dim=1` partitions along `out_features`, but it should partition along `in_features` → `dim=2`
- `"column"` parallel → `dim=0` partitions along `num_gemms`, but it should partition along `out_features` → `dim=1`
This causes the wrong axis to be sharded, breaking any tensor-parallel run that uses single_grouped_parameter=True.
Suggested change:

```python
set_tensor_model_parallel_attributes(
    tensor=grouped_weight,
    is_parallel=True,
    dim=2 if self.parallel_mode == "row" else 1,
    stride=1,
)
```
```python
super().__torch_dispatch__(func, types, new_args, new_kwargs)
for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
    maybe_update_inplace(arg, new_arg, schema_arg)
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
    assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema"
    maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
return None
```

```python
# Default op: operate on dequantized stacked tensors.
new_args = tree_map(maybe_unwrap, args)
new_kwargs = tree_map(maybe_unwrap, kwargs)
return super().__torch_dispatch__(func, types, new_args, new_kwargs)
```
super().__torch_dispatch__ passes original types containing GroupedTensor
Both the in-place and default dispatch paths call:

```python
super().__torch_dispatch__(func, types, new_args, new_kwargs)
```

At this point `new_args` has already been unwrapped (all GroupedTensor instances replaced with plain stacked tensors), but `types` still contains GroupedTensor. PyTorch's C++ dispatch layer examines `types` when deciding whether to re-dispatch; passing the original `types` while the actual tensor arguments are plain tensors can cause the dispatch to call `GroupedTensor.__torch_dispatch__` again, leading to infinite recursion.
The idiomatic pattern for a wrapper subclass that has already substituted all its arguments is to call the op directly:

```python
# In-place path:
func(*new_args, **new_kwargs)

# Default path:
return func(*new_args, **new_kwargs)
```

This avoids any re-dispatch via `types` and directly executes the ATen kernel on the unwrapped tensors.
```python
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
    assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema"
    maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
```
Kwargs-to-schema alignment is fragile
The loop

```python
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
    assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema"
```

makes two implicit assumptions that can break:

1. **Order:** `kwargs` (and `new_kwargs`) enumerate their keys in insertion order, but `schema_args[args_len:]` lists all remaining schema arguments, including any that were not actually passed as kwargs. If the caller omits an optional argument that appears before the first actual kwarg in the schema, the zip would pair them incorrectly, triggering the assertion.
2. **Coverage:** `zip` silently stops at the shortest iterable, so in-place writeback for kwargs that appear later in the schema than the number of passed kwargs is silently skipped.
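The ordering pitfall can be reproduced with a toy schema (hypothetical argument names, standing in for real ATen schema args):

```python
from collections import namedtuple

SchemaArg = namedtuple("SchemaArg", ["name"])

# Trailing schema args; the caller skips the optional "beta".
schema_args = [SchemaArg("alpha"), SchemaArg("beta"), SchemaArg("out")]
kwargs = {"alpha": 1, "out": 2}

# Positional zip mispairs "out" with the schema's "beta" slot.
pairs = list(zip(kwargs, schema_args))
assert pairs == [("alpha", SchemaArg("alpha")), ("out", SchemaArg("beta"))]

# Name-based lookup pairs every passed kwarg with its own schema arg.
by_name = {a.name: a for a in schema_args}
assert all(by_name[k].name == k for k in kwargs)
```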
A safer approach is to match kwargs by name against the schema:

```python
schema_arg_by_name = {a.name: a for a in schema_args[args_len:]}
for kwarg in kwargs:
    schema_arg = schema_arg_by_name.get(kwarg)
    if schema_arg is not None:
        maybe_update_inplace(kwargs[kwarg], new_kwargs[kwarg], schema_arg)
```

```python
    return arg

def update_grouped_tensor_inplace(grouped: GroupedTensor, updated: torch.Tensor):
    if not grouped.all_same_shape():
```
Since this is also repeated in grouped_to_stacked_tensor, might it make sense to move this to the start of the function and do it once?
They are 2 separate checks but the execution only happens once, either for the inplace case, or the normal case.
Oh, I was saying this from a code-duplication perspective. Anyway, this is minor; will leave it up to you.
```cpp
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
std::once_flag extension_init_flag;
```
This flag is used below to initialize all extensions once
Yep, just moved 2 lines below
```python
strides = [1] * len(wrapper_shape)
for i in range(len(wrapper_shape) - 2, -1, -1):
    strides[i] = strides[i + 1] * wrapper_shape[i + 1]
```
In other classes, I have recently added `stride` as an argument to the constructor. This allows the class to be created from C++ with less CPU overhead; any Python compute code in `__new__` has a lot of CPU overhead. This logic could run only when `stride` is not provided (which would be the object-creation-in-Python path).
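A sketch of that suggestion (hypothetical helper name): compute row-major strides in Python only when the caller does not supply them:

```python
def contiguous_strides(shape, strides=None):
    """Row-major (contiguous) strides; trust precomputed strides when given."""
    if strides is not None:
        return strides  # e.g. passed in from C++ to skip the Python loop
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

assert contiguous_strides([4, 8, 16]) == [128, 16, 1]
assert contiguous_strides([4, 8, 16], strides=[128, 16, 1]) == [128, 16, 1]
```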
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
```python
# Parameter construction calls detach()/alias-like paths.
if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
    return args[0]
```
detach short-circuit returns the original GroupedTensor without detaching
Returning args[0] directly for detach.default and alias.default means the returned tensor is the same object as the original — it is not a new detached view. This is intentional for wrapper subclasses in some cases, but it has a subtle consequence: torch.nn.Parameter(grouped_weights) internally calls detach() and expects a logically detached tensor. Returning the original GroupedTensor means there is only one object being both the raw tensor and the nn.Parameter backing storage, which can confuse gradient tracking and .data access patterns.
A more conventional pattern for wrapper subclasses is to create a shallow clone via `_make_wrapper_subclass` to represent the detached view, preserving the same storage but returning a distinct object:

```python
if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
    out = args[0].__class__.__new__(args[0].__class__, ...)
    # copy over GroupedTensorStorage state fields
    return out
```

At minimum, consider documenting why returning `args[0]` directly is safe here.
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
for more information, see https://pre-commit.ci
```python
if grouped.quantizer is None:
    grouped_members = grouped.quantized_tensors
    if grouped_members is None:
        grouped_members = grouped.split_into_quantized_tensors()
```
Add some fallback logic here:

- if a grouped quantize is available, use it
- if it's not available, split and trigger quantize one by one
```python
)
grouped_members = grouped.quantized_tensors
if grouped_members is None:
    grouped_members = grouped.split_into_quantized_tensors()
```
same comment, add some fallback check
```python
# Re-register as a single grouped weight parameter.
# Re-register as a single grouped weight parameter.
```
Duplicate comment line
The comment # Re-register as a single grouped weight parameter. appears on both line 789 and 790 — this is a copy-paste artifact. Remove one of them.
Suggested change:

```python
# Re-register as a single grouped weight parameter.
assert isinstance(grouped_weights, torch.Tensor) and (
```
zhongbozhu left a comment
Mostly LGTM. We need to add some logic around `split_into_quantized_tensors`, since this API is not going to be performant: we shouldn't need to split and then quantize once a grouped quantize kernel for weights is ready.
/te-ci pytorch L0
```python
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
    """Get the weight tensors of the module."""
    weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
    grouped_weight = getattr(self, "weight", None)
```
There is another case:

- Suppose FP8 primary weight is already supported for the grouped single weight; then you also shouldn't split it, because this means you can call the GEMM directly.
- Suppose we didn't turn on FP8 primary weight but still turned on the single weight; then we should 1) call the grouped quantize kernel for weights if it is supported, or 2) if it's not, call this API to split and quantize — so it's actually better to call a split_quantize instead, which is more performant.

This basically means that when `_get_weight_tensors` is called, it should not try to split beforehand; we should instead try to call `grouped_quantize` or `split_quantize` here:
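A hedged sketch of the fallback chain described above (`grouped_quantize`, `split_quantize`, and the per-tensor `quantize` are hypothetical method names used for illustration, not confirmed APIs):

```python
def quantize_grouped_weight(weight_members, quantizer):
    """Prefer the fused grouped kernel; fall back to split+quantize, then per-tensor."""
    if hasattr(quantizer, "grouped_quantize"):
        # Best case: one fused kernel over all members.
        return quantizer.grouped_quantize(weight_members)
    if hasattr(quantizer, "split_quantize"):
        # Next best: split and quantize in a single call.
        return quantizer.split_quantize(weight_members)
    # Last resort: quantize each member one by one.
    return [quantizer.quantize(w) for w in weight_members]

class PerTensorQuantizer:
    """Toy quantizer exposing only the per-tensor path."""
    def quantize(self, w):
        return ("q", w)

assert quantize_grouped_weight([1, 2], PerTensorQuantizer()) == [("q", 1), ("q", 2)]
```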
Description

Support option to expose single parameter for the `GroupedLinear` module.

Type of change

Changes

- Support option to expose single parameter for the `GroupedLinear` module.

Checklist: